#include "AscendGaussianBlur.h"
#include <sys/time.h>
/**
 * Builds, compiles and loads a one-operator graph that Gaussian-blurs a
 * 1 x _src_h x _src_w x _channel FP16 image on the Ascend device, then creates
 * the ACL context/stream and allocates the device I/O buffers used by run().
 *
 * The blur is a grouped (depthwise) Conv2D: every channel is convolved with the
 * same kernel_size x kernel_size kernel built from cv::getGaussianKernel.
 *
 * @param _src_w      image width the graph is compiled for
 * @param _src_h      image height the graph is compiled for
 * @param _channel    number of interleaved image channels (one filter group each)
 * @param kernel_size side length of the square Gaussian kernel
 * @param sigma       Gaussian standard deviation forwarded to cv::getGaussianKernel
 *
 * Errors are only logged to stdout. If the graph build or model initialization
 * fails, no I/O buffers are allocated. If an individual aclrtMalloc fails, a
 * nullptr is stored for that slot.
 *
 * NOTE(review): assumes aclInit()/aclrtSetDevice(0) were already called by the
 * owner before construction — confirm.
 */
AscendGaussianBlur::AscendGaussianBlur(int _src_w, int _src_h, int _channel, int kernel_size, double sigma) {
    std::map<AscendString, AscendString> global_options = {
            {ge::ir_option::SOC_VERSION, "Ascend310"},
    };
    auto ret = aclgrphBuildInitialize(global_options);
    if (ret != GRAPH_SUCCESS) {
        cout << "Build Initialize Failed!" << ret << endl;
    }
    src_h = _src_h;
    src_w = _src_w;
    channel = _channel;

    // Graph input: a single FP16 image laid out as (N=1, H, W, C).
    TensorDesc data_desc(ge::Shape({1, src_h, src_w, channel}), FORMAT_ND, DT_FLOAT16);
    auto data = op::Data("data").set_attr_index(0);
    data.update_input_desc_x(data_desc);
    data.update_output_desc_y(data_desc);

    // 2-D Gaussian = outer product of the 1-D kernel with itself.
    Mat gaussianKernel1D = getGaussianKernel(kernel_size, sigma, CV_32F);
    Mat gaussianKernel2D = gaussianKernel1D * gaussianKernel1D.t();
    // One identical copy per channel (previously hard-coded to 3, which broke any
    // channel != 3). merge() interleaves them into an [H][W][channel] buffer, which
    // is exactly the byte order of the (H, W, 1, channel) HWCN filter below.
    vector<Mat> gaussianKernels(channel, gaussianKernel2D);
    Mat gaussianKernel;
    merge(gaussianKernels, gaussianKernel);
    gaussianKernel.convertTo(gaussianKernel, CV_16F);

    TensorDesc blur_kernel_weight_desc(ge::Shape({kernel_size, kernel_size, 1, channel}), FORMAT_ND, DT_FLOAT16);
    Tensor blur_kernel_weight(blur_kernel_weight_desc,
                              gaussianKernel.data, gaussianKernel.total() * gaussianKernel.elemSize());
    auto blur_kernel = op::Const("blur/kernel").set_attr_value(blur_kernel_weight);

    // Stride is fixed at 1, so (kernel_size - 1) / 2 on every side keeps H and W
    // unchanged for odd kernel sizes.
    int stride = 1;
    int padding = (kernel_size - stride) / 2;

    auto blur = op::Conv2D("blur")
            .set_input_x(data)
            .set_input_filter(blur_kernel)
            .set_attr_strides({ 1, stride, stride, 1 })   // NHWC order; N and C strides are 1
            .set_attr_pads({ padding, padding, padding, padding })
            .set_attr_groups(channel);                   // depthwise: one group per channel

    TensorDesc conv2d_input_desc_x(ge::Shape(), FORMAT_NHWC, DT_FLOAT16);
    TensorDesc conv2d_input_desc_filter(ge::Shape(), FORMAT_HWCN, DT_FLOAT16);
    TensorDesc conv2d_output_desc_y(ge::Shape(), FORMAT_NHWC, DT_FLOAT16);
    blur.update_input_desc_x(conv2d_input_desc_x);
    blur.update_input_desc_filter(conv2d_input_desc_filter);
    blur.update_output_desc_y(conv2d_output_desc_y);

    graph = Graph("BlurGraph");
    std::vector<Operator> inputs{ data };
    std::vector<Operator> outputs{ blur };
    graph.SetInputs(inputs).SetOutputs(outputs);

    ModelBufferData model;
    std::map<AscendString, AscendString> options;

    ret = aclgrphBuildModel(graph, options, model);
    const bool modelBuilt = (ret == GRAPH_SUCCESS);
    if (modelBuilt) {
        cout << "Build Model SUCCESS!" << endl;
    }
    else {
        cout << "Build Model Failed!"<< ret << endl;
    }
    // To dump the offline model for inspection: aclgrphSaveModel("graph", model);

    // Context and stream are always created so the destructor has valid handles
    // to release, even when the model itself could not be built.
    aclError aclRet = aclrtCreateContext(&m_context, 0);
    if (aclRet != ACL_ERROR_NONE) {
        cout << "Failed to create context, ret = " << aclRet << endl;
    } else {
        cout << "Create context successfully" << endl;
    }
    aclRet = aclrtSetCurrentContext(m_context);
    if (aclRet != ACL_ERROR_NONE) {
        cout << "Failed to set current context, ret = " << aclRet << endl;
    } else {
        cout << "set context successfully" << endl;
    }
    aclRet = aclrtCreateStream(&m_stream);
    if (aclRet != ACL_ERROR_NONE) {
        cout << "Failed to create stream, ret = " << aclRet << endl;
    } else {
        cout << "Create stream successfully" << endl;
    }

    if (!modelBuilt) {
        return;  // nothing to load; failure already logged above
    }

    if (m_modelProcess == nullptr) {
        m_modelProcess = std::make_shared<ModelProcess>(0, "");
    }
    aclRet = m_modelProcess->Init(model.data.get(), model.length);
    if (aclRet != ACL_ERROR_NONE) {
        cout << "Failed to initialize m_modelProcess, ret = " << aclRet << endl;
        return;  // no valid model description to size buffers from
    }
    m_modelDesc = m_modelProcess->GetModelDesc();

    // Allocate one device buffer per model input, sized by the model description.
    size_t inputSize = aclmdlGetNumInputs(m_modelDesc);
    for (size_t i = 0; i < inputSize; i++) {
        size_t bufferSize = aclmdlGetInputSizeByIndex(m_modelDesc, i);
        void *inputBuffer = nullptr;
        aclRet = aclrtMalloc(&inputBuffer, bufferSize, ACL_MEM_MALLOC_NORMAL_ONLY);
        if (aclRet != ACL_ERROR_NONE) {
            // Report the aclrtMalloc result itself (previously a stale code was printed).
            cout << "Failed to malloc buffer, ret = " << aclRet << endl;
        }
        inputBuffers.push_back(inputBuffer);
        inputSizes.push_back(bufferSize);
    }
    // Allocate one device buffer per model output, sized by the model description.
    size_t outputSize = aclmdlGetNumOutputs(m_modelDesc);
    for (size_t i = 0; i < outputSize; i++) {
        size_t bufferSize = aclmdlGetOutputSizeByIndex(m_modelDesc, i);
        void *outputBuffer = nullptr;
        aclRet = aclrtMalloc(&outputBuffer, bufferSize, ACL_MEM_MALLOC_NORMAL_ONLY);
        if (aclRet != ACL_ERROR_NONE) {
            cout << "Failed to malloc buffer, ret = " << aclRet << endl;
        }
        outputBuffers.push_back(outputBuffer);
        outputSizes.push_back(bufferSize);
    }
}

/**
 * Tears down the constructor's resources: finalizes the graph builder, drops
 * the ModelProcess, waits for the stream, frees the device I/O buffers, and
 * destroys the stream and context. Failures are logged and teardown continues.
 */
AscendGaussianBlur::~AscendGaussianBlur() {
    aclgrphBuildFinalize();
    // NOTE(review): presumably dropping the last reference unloads the model — confirm in ModelProcess.
    m_modelProcess = nullptr;

    aclError ret = aclrtSynchronizeStream(m_stream);
    if (ret != ACL_ERROR_NONE) {
        cout << "some tasks in stream not done, ret = " << ret <<endl;
    } else {
        cout << "all tasks in stream done" << endl;
    }

    // The constructor allocates these with aclrtMalloc and, as far as this file
    // shows, never frees them. Release them while their context still exists.
    for (void *buffer : inputBuffers) {
        if (buffer != nullptr && aclrtFree(buffer) != ACL_ERROR_NONE) {
            cout << "Free input buffer faild" << endl;
        }
    }
    inputBuffers.clear();
    for (void *buffer : outputBuffers) {
        if (buffer != nullptr && aclrtFree(buffer) != ACL_ERROR_NONE) {
            cout << "Free output buffer faild" << endl;
        }
    }
    outputBuffers.clear();

    ret = aclrtDestroyStream(m_stream);
    if (ret != ACL_ERROR_NONE) {
        cout << "Destroy Stream faild, ret = " << ret <<endl;
    } else {
        cout << "Destroy Stream successfully" << endl;
    }
    ret = aclrtDestroyContext(m_context);
    if (ret != ACL_ERROR_NONE) {
        cout << "Destroy Context faild, ret = " << ret <<endl;
    } else {
        cout << "Destroy Context successfully" << endl;
    }
}

// Milliseconds elapsed between two gettimeofday() samples.
static double elapsedMs(const struct timeval &start, const struct timeval &end) {
    return (end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) / 1000.0;
}

/**
 * Blurs one image with the compiled Ascend graph.
 *
 * @param src host image; must be src_w x src_h with `channel` channels. Its
 *            element depth is converted to FP16 before upload.
 * @return    a src.size() FP16 Mat with `channel` channels holding the blurred
 *            image. It is left zero-filled when validation, the upload, or the
 *            inference fails (all failures are logged).
 *
 * Prints the wall-clock time of each stage to stdout.
 */
Mat AscendGaussianBlur::run(Mat& src) {
    struct timeval start;
    struct timeval end;
    Mat temp;
    gettimeofday(&start,NULL);
    // Match the model's output layout (`channel` FP16 channels), not a fixed 3.
    Mat dst = Mat(src.size(), CV_MAKETYPE(CV_16F, channel), Scalar::all(0));
    gettimeofday(&end,NULL);
    cout<<"--create dst time : "<<elapsedMs(start, end) <<"ms"<<endl;

    if (src.cols != src_w || src.rows != src_h) {
        cout << "input size not eq to define size" << endl;
        return dst;
    }
    // A channel mismatch would make the copies below read/write out of bounds.
    if (src.channels() != channel) {
        cout << "input channels not eq to define channels" << endl;
        return dst;
    }
    if (inputBuffers.empty() || outputBuffers.empty()) {
        cout << "model buffers are not allocated" << endl;
        return dst;
    }

    gettimeofday(&start,NULL);
    src.convertTo(temp, CV_16F);
    gettimeofday(&end,NULL);
    cout<<"--convert to fp16 time : "<<elapsedMs(start, end) <<"ms"<<endl;

    // temp is freshly allocated by convertTo, hence continuous.
    const size_t tempBytes = temp.total() * temp.elemSize();
    if (tempBytes < inputSizes[0]) {
        cout << "converted input smaller than model input: " << tempBytes << " < " << inputSizes[0] << endl;
        return dst;
    }

    gettimeofday(&start,NULL);
    aclError ret = aclrtMemcpy(inputBuffers[0], inputSizes[0], (void*)temp.data, inputSizes[0], ACL_MEMCPY_HOST_TO_DEVICE);
    gettimeofday(&end,NULL);
    if (ret != ACL_ERROR_NONE) {
        cout<<"copy data to device faild.ret = "<< ret <<endl;
        return dst;
    }
    cout<<"--copy to device time : "<<elapsedMs(start, end) <<"ms"<<endl;

    gettimeofday(&start,NULL);
    ret = m_modelProcess->ModelInference(inputBuffers, inputSizes, outputBuffers, outputSizes);
    gettimeofday(&end,NULL);
    if (ret != ACL_ERROR_NONE) {
        cout<<"model run faild.ret = "<< ret <<endl;
        return dst;
    }
    cout<<"--model infer time : "<<elapsedMs(start, end) <<"ms"<<endl;

    // Never copy more bytes than dst actually owns.
    const size_t dstBytes = dst.total() * dst.elemSize();
    if (outputSizes[0] > dstBytes) {
        cout << "model output larger than dst: " << outputSizes[0] << " > " << dstBytes << endl;
        return dst;
    }
    gettimeofday(&start,NULL);
    ret = aclrtMemcpy((void*)dst.data, dstBytes, outputBuffers[0], outputSizes[0], ACL_MEMCPY_DEVICE_TO_HOST);
    gettimeofday(&end,NULL);
    if (ret != ACL_ERROR_NONE) {
        cout<<"copy data to host failed.ret = "<< ret <<endl;
    }
    cout<<"--copy to host time : "<<elapsedMs(start, end) <<"ms"<<endl;
    return dst;
}
